---
# Task definition: trains the MNIST example from the tensorflow/models
# repository on a `tpu-v2-8` accelerator.
name: mnist-tpu-node

# Compute requested for the task: a host instance plus a TPU accelerator.
resources:
  instance_type: n1-highmem-8
  accelerators: tpu-v2-8
  accelerator_args:
    # Quoted so the version is always read as a string. Keep in sync with
    # the tensorflow version that `setup` installs.
    runtime_version: '2.12.0'
    # Canonical lowercase boolean; same parsed value as `False`.
    tpu_vm: false

# Attach the GCS bucket `demo-mnist-tpu` at /dataset in MOUNT mode.
# NOTE(review): `run` addresses the same bucket by its gs:// URL rather than
# through /dataset -- confirm the mount path itself is still needed.
file_mounts:
  /dataset:
    store: gcs
    name: demo-mnist-tpu
    mode: MOUNT


# The setup command.  Executed from the task's working directory, and written
# so that it can be re-run safely on a machine that was already set up.
setup: |
  # Clone only once: `git clone` fails when `models/` already exists.
  if [ ! -d models ]; then
    git clone https://github.com/tensorflow/models.git
  fi

  # Reuse the conda env when it activates cleanly; otherwise create it.
  if conda activate mnist; then
    echo 'conda env exists'
  else
    conda create -n mnist python=3.8 -y
    conda activate mnist
  fi

  # Installed on every setup, not only when the env is first created, so an
  # env left half-provisioned by an interrupted setup gets repaired.  pip
  # skips requirements that are already satisfied.
  pip install tensorflow==2.12.0 tensorflow-datasets tensorflow-model-optimization cloud-tpu-client

# The command to run.  Executed from the task's working directory, which is
# where `setup` cloned the `models` repository.
run: |
  conda activate mnist

  # Point Python at the cloned repository.  Derived from the current
  # directory (before the `cd` below) instead of hard-coding one user's home.
  export PYTHONPATH="${PWD}/models"

  # Abort if the checkout is missing rather than running from the wrong place.
  cd models/official/legacy/image_classification/ || exit 1

  # Model checkpoints and the dataset both live in the same GCS bucket.
  export STORAGE_BUCKET=gs://demo-mnist-tpu
  export MODEL_DIR="${STORAGE_BUCKET}/mnist"
  export DATA_DIR="${STORAGE_BUCKET}/data"

  # TPU_NAME is not set in this file; it must come from the environment.
  python3 mnist_main.py \
    --tpu="${TPU_NAME}" \
    --model_dir="${MODEL_DIR}" \
    --data_dir="${DATA_DIR}" \
    --train_epochs=10 \
    --distribution_strategy=tpu \
    --download
